package com.gitee.jmash.rbac.client.shiro.grpc;

import com.gitee.jmash.common.grpc.GrpcContext;
import com.gitee.jmash.common.grpc.GrpcMetadata;
import com.gitee.jmash.common.security.DefaultJmashPrincipal;
import com.gitee.jmash.common.security.JmashPrincipal;
import com.gitee.jmash.core.grpc.cdi.GrpcServerInterceptor;
import com.gitee.jmash.core.utils.TenantUtil;
import com.gitee.jmash.rbac.client.shiro.authc.JmashShiroJwtToken;
import com.gitee.jmash.rbac.client.token.OrganUserAccessToken;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.smallrye.jwt.auth.principal.JWTParser;
import jakarta.annotation.Priority;
import jakarta.enterprise.inject.spi.CDI;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.session.ExpiredSessionException;
import org.apache.shiro.session.Session;
import org.apache.shiro.session.UnknownSessionException;
import org.apache.shiro.session.mgt.DefaultSessionContext;
import org.apache.shiro.session.mgt.DefaultSessionKey;
import org.apache.shiro.session.mgt.SessionContext;
import org.apache.shiro.session.mgt.SessionKey;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.subject.SubjectContext;
import org.apache.shiro.subject.support.DefaultSubjectContext;
import org.eclipse.microprofile.jwt.JsonWebToken;

/**
 * OAuth interceptor: authenticates gRPC calls carrying a bearer JWT through Apache Shiro.
 *
 * <p>Calls without a usable bearer credential are passed through unauthenticated (any non-bearer
 * authorization value is exposed via {@code GrpcContext.OTHER_AUTH}). Calls with a bearer token
 * are parsed, optionally swapped to a tenant-specific token, logged in to a Shiro subject, and on
 * success continue with the token, JWT and subject stored in the gRPC {@link Context}. Failures
 * close the call with {@link Status#UNAUTHENTICATED}.
 *
 * @author CGD
 *
 */
@GrpcServerInterceptor
@Priority(1000)
public class ShiroAuthInterceptor implements ServerInterceptor {

  private static final Log logger = LogFactory.getLog(ShiroAuthInterceptor.class);

  /**
   * Typed view of {@code GrpcContext.USER_SUBJECT}; downstream code can read it directly with
   * {@code Subject subject = (Subject) GrpcContext.USER_SUBJECT.get();}.
   */
  @SuppressWarnings("unchecked")
  private static final Context.Key<Subject> USER_SUBJECT =
      (Context.Key<Subject>) GrpcContext.USER_SUBJECT;

  public ShiroAuthInterceptor() throws Exception {
    super();
  }

  /** Looks up the {@link JWTParser} bean from the current CDI container on every call. */
  public JWTParser getJWTParser() {
    return CDI.current().select(JWTParser.class).get();
  }

  @Override
  public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
      ServerCallHandler<ReqT, RespT> next) {
    Context context = Context.current();
    String tenant = headers.get(GrpcMetadata.TENANT_KEY);
    if (StringUtils.isNotBlank(tenant)) {
      context = context.withValue(GrpcContext.TENANT, tenant);
    }
    String authorization = headers.get(GrpcMetadata.AUTH_KEY);
    // No usable bearer credential: continue unauthenticated and let the downstream permission
    // checks decide. A non-bearer authorization value is still handed on via OTHER_AUTH.
    if (StringUtils.isBlank(authorization)
        || authorization.length() <= GrpcMetadata.BEARER.length() + 1
        || !authorization.startsWith(GrpcMetadata.BEARER)) {
      if (StringUtils.isNotBlank(authorization)) {
        context = context.withValue(GrpcContext.OTHER_AUTH, authorization);
      }
      return Contexts.interceptCall(context, call, headers, next);
    }

    // Skip the scheme plus one separator character (presumably a space).
    String accessToken = authorization.substring(GrpcMetadata.BEARER.length() + 1).trim();
    JmashPrincipal jmashPrincipal = parseJwt(call, headers, accessToken);
    if (jmashPrincipal == null) {
      // parseJwt has already closed the call.
      return closedCallListener();
    }

    if (TenantUtil.hasTenantIdentifier(tenant)
        && !StringUtils.equals(jmashPrincipal.getTenant(), tenant)) {
      // Organization tenant requested (carries the organization ID, e.g. organID@e01) that
      // differs from the token's tenant: switch to the organization-user access token.
      accessToken = OrganUserAccessToken.getAccessToken(accessToken, tenant, jmashPrincipal);
      jmashPrincipal = parseJwt(call, headers, accessToken);
    } else if (StringUtils.isNotBlank(tenant)
        && TenantUtil.hasTenantIdentifier(jmashPrincipal.getTenant())
        && !StringUtils.equals(jmashPrincipal.getTenant(), tenant)) {
      // Token is an organization-tenant token but another tenant was requested:
      // switch back to the origin access token.
      accessToken = OrganUserAccessToken.getOriginAccessToken(tenant, jmashPrincipal);
      jmashPrincipal = parseJwt(call, headers, accessToken);
    }
    if (jmashPrincipal == null) {
      // Re-parsing the swapped token failed; parseJwt has already closed the call.
      return closedCallListener();
    }

    Subject subject = getSubject(jmashPrincipal.getTenant(), jmashPrincipal.getName(),
        jmashPrincipal.getClientId());
    if (!subject.isAuthenticated()) {
      try {
        subject.login(new JmashShiroJwtToken(accessToken));
      } catch (AuthenticationException ex) {
        // Shiro reports a rejected login by throwing. Fall through to the 401 branch below rather
        // than letting the exception escape, which gRPC would surface as an UNKNOWN status.
        logger.warn("Shiro login failed: " + ex.getMessage());
      }
    }
    if (!subject.isAuthenticated()) {
      logger.error("401 Unauthorized ");
      closeUnauthenticated(call, "401 Unauthorized Client Connection.");
      return closedCallListener();
    }

    printSubject(subject, tenant);
    context = context.withValue(GrpcContext.USER_AUTH, accessToken)
        .withValue(GrpcContext.USER_TOKEN, (JsonWebToken) subject.getPrincipal())
        .withValue(USER_SUBJECT, subject);
    return Contexts.interceptCall(context, call, headers, next);
  }

  /**
   * Parses a raw JWT into a {@link JmashPrincipal}.
   *
   * <p>On any parse failure the call is closed with {@link Status#UNAUTHENTICATED} and
   * {@code null} is returned; callers must then stop processing the call.
   *
   * @param call the call to close on failure
   * @param headers the request headers; kept for subclass compatibility and deliberately NOT
   *     echoed back as trailers (they contain the bearer token)
   * @param userAuthToken the raw JWT (without the bearer prefix)
   * @return the principal, or {@code null} if the token could not be parsed
   */
  protected <ReqT, RespT> JmashPrincipal parseJwt(ServerCall<ReqT, RespT> call, Metadata headers,
      String userAuthToken) {
    try {
      JsonWebToken webToken = getJWTParser().parse(userAuthToken);
      return DefaultJmashPrincipal.create(webToken);
    } catch (Exception ex) {
      logger.error("JWT ERROR 401 Error Authorization :" + ex.getMessage());
      closeUnauthenticated(call, " 401 Error Authorization Client Connection.");
    }
    return null;
  }

  /**
   * Closes {@code call} with {@link Status#UNAUTHENTICATED}. Uses empty trailers so the request
   * metadata (including the Authorization header) is not reflected back to the client.
   */
  private static void closeUnauthenticated(ServerCall<?, ?> call, String description) {
    call.close(Status.UNAUTHENTICATED.withDescription(description), new Metadata());
  }

  /** Returns a listener that ignores every event; used once the call has already been closed. */
  private static <ReqT> ServerCall.Listener<ReqT> closedCallListener() {
    return new ServerCall.Listener<ReqT>() {};
  }

  /** Logs the authenticated subject (IP, user, token tenant, requested tenant, run-as info). */
  public void printSubject(Subject subject, String organTenant) {
    JsonWebToken p = (JsonWebToken) subject.getPrincipal();
    StringBuilder s = new StringBuilder(" IP:" + GrpcContext.USER_IP.get() + ", ");
    s.append(String.format("UserId:%s, UserName:%s,Token Tenant:%s , Organ Tenant: %s", p.getName(),
        p.getSubject(), p.getIssuer(), organTenant));
    if (subject.isRunAs()) {
      // Separator was missing before the run-as section.
      s.append(", runAs:true, ");
      JsonWebToken pre = (JsonWebToken) subject.getPreviousPrincipals().getPrimaryPrincipal();
      s.append(String.format("PreviousUserId:%s, PreviousUserName:%s, ", pre.getName(),
          pre.getSubject()));
    }
    logger.info(s);
  }

  /**
   * Returns the Shiro {@link Subject} bound to the session {@code tenant:clientId:userId}.
   *
   * <p>An existing session with that ID is reused; if the security manager reports it unknown or
   * expired, a new session is started with that ID.
   */
  public Subject getSubject(String tenant, String userId, String clientId) {
    String sessionId = tenant + ":" + clientId + ":" + userId;
    try {
      // Restore the existing session.
      SessionKey key = new DefaultSessionKey(sessionId);
      return createSubject(SecurityUtils.getSecurityManager().getSession(key));
    } catch (UnknownSessionException | ExpiredSessionException ex) {
      // No usable session: start a new one under the same ID.
      SessionContext context = new DefaultSessionContext();
      context.setSessionId(sessionId);
      return createSubject(SecurityUtils.getSecurityManager().start(context));
    }
  }

  /** Builds a Subject bound to the given session via the global security manager. */
  private static Subject createSubject(Session session) {
    SubjectContext sc = new DefaultSubjectContext();
    sc.setSession(session);
    return SecurityUtils.getSecurityManager().createSubject(sc);
  }
}
